/*
    Copyright 2006-2012 Patrik Jonsson, sunrise@familjenjonsson.org

    This file is part of Sunrise.

    Sunrise is free software; you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation; either version 3 of the License, or
    (at your option) any later version.

    Sunrise is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with Sunrise.  If not, see <http://www.gnu.org/licenses/>.

*/

/// \file
/// Definition of the nbody_data_grid and related classes. \ingroup makegrid

#ifndef __sphgrid__
#define __sphgrid__

#include "grid.h"
#include "blitz/array.h"
#include "boost/shared_ptr.hpp"
#include "boost/thread/mutex.hpp"
#include <vector>
#include <algorithm> 
#include "sphparticle.h"
#include "counter.h"
#include "refinement_accuracy_data.h"

// Forward declarations of the mcrx names used in this header.
namespace mcrx {
  // Only named in this header, not defined in it.
  class Snapshot;
  template <typename> class particle;

  // Classes whose definitions follow further down.
  class nbody_data;
  class nbody_data_cell;
  class tolerance_checker;
  class nbody_data_grid;
  class check_particle_overlap_p;
}

// Forward declarations of the CCfits classes that this header uses.
// They only appear as references or pointers in the declarations
// below, so incomplete types are sufficient here.  ExtHDU is
// included because nbody_data_grid takes CCfits::ExtHDU& arguments.
namespace CCfits {
  class FITS;
  class HDU;
  class ExtHDU;
}

// The global main is declared here so that the classes below can name
// it in their qualified friend declarations
// (friend int ::main(int, char**)).
int main (int, char**);

/** Holds the physical quantities of the particles output by sfrhist.
    The same quantities are computed for every grid cell when the
    adaptive grid is constructed.  \ingroup makegrid */
class mcrx::nbody_data {
  friend class nbody_data_grid;
  friend int ::main(int, char**);

public:
  typedef mcrx::T_float T_float;

protected:
  T_float m_g_;        ///< Gas mass.
  vec3d p_g_;          ///< Gas momentum.
  T_float L_bol_;      ///< Bolometric luminosity.
  T_float SFR;         ///< Star-formation rate.
  T_float m_metals_;   ///< Mass in metals (in the gas).
  T_float m_s_;        ///< Stellar mass.
  vec3d p_s_;          ///< Stellar momentum.
  T_float m_m_s;       ///< Mass in metals (in the stars).
  T_float age_m;       ///< Mass-weighted stellar age.
  T_float age_l;       ///< Luminosity-weighted stellar age.
  T_float gas_temp_m;  ///< Mass-weighted gas temperature.
  T_float gas_teff_m;  ///< Mass-weighted gas effective temperature.

  array_1 L_lambda;    ///< Spectral energy distribution of stellar emission.

public:
  /// Default constructor: every scalar and both momenta are zero, and
  /// L_lambda is a thread-local copy of a default-constructed array.
  nbody_data ()
    : m_g_ (0),
      p_g_ (0, 0, 0),
      L_bol_ (0),
      SFR (0),
      m_metals_ (0),
      m_s_ (0),
      p_s_ (0, 0, 0),
      m_m_s (0),
      age_m (0),
      age_l (0),
      gas_temp_m (0),
      gas_teff_m (0),
      L_lambda (make_thread_local_copy (array_1 ()))
  {}

  /** Member-wise constructor.  Note that mtemp and mteff are
      temperature*mass, not the temperatures themselves.  L_lambda
      is initialized with a thread-local copy of Ll. */
  nbody_data (T_float m_gas, const vec3d& p_gas, T_float Lb,
              T_float s, T_float met,
              T_float m_star, const vec3d& p_star, T_float starmet,
              T_float agem, T_float agel,
              T_float mtemp, T_float mteff,
              const array_1& Ll = array_1 ())
    : m_g_ (m_gas),
      p_g_ (p_gas),
      L_bol_ (Lb),
      SFR (s),
      m_metals_ (met),
      m_s_ (m_star),
      p_s_ (p_star),
      m_m_s (starmet),
      age_m (agem),
      age_l (agel),
      gas_temp_m (mtemp),
      gas_teff_m (mteff),
      L_lambda (make_thread_local_copy (Ll))
  {}

  /** Copy constructor.  It does NOT have reference semantics:
      L_lambda is created with the size of d.L_lambda and the
      contents of d.L_lambda are then assigned into it. */
  nbody_data (const nbody_data& d)
    : m_g_ (d.m_g_),
      p_g_ (d.p_g_),
      L_bol_ (d.L_bol_),
      SFR (d.SFR),
      m_metals_ (d.m_metals_),
      m_s_ (d.m_s_),
      p_s_ (d.p_s_),
      m_m_s (d.m_m_s),
      age_m (d.age_m),
      age_l (d.age_l),
      gas_temp_m (d.gas_temp_m),
      gas_teff_m (d.gas_teff_m),
      /// \todo no thread local copy here??
      L_lambda (d.L_lambda.size ())
  {
    L_lambda = d.L_lambda;
  }

  ///@{
  /// \name Access operators

  /// Gas mass.
  T_float m_g () const { return m_g_; }
  /// Mass in metals (in the gas).
  T_float m_metals () const { return m_metals_; }
  /// Gas momentum.  In builds with assertions enabled, the momentum
  /// must be all zero or satisfy |p_g|/m_g < 1e-4.
  const vec3d& p_g () const {
    assert (all (p_g_ == 0) || (sqrt (dot (p_g_, p_g_)) / m_g_ < 1e-4));
    return p_g_;
  }
  /// Bolometric luminosity.
  T_float L_bol () const { return L_bol_; }
  /// Gas temperature, i.e. the stored temperature*mass divided by
  /// the gas mass.
  T_float gas_temp () const { return gas_temp_m / m_g_; }
  /// Gas effective temperature, i.e. the stored value divided by the
  /// gas mass.
  T_float gas_teff () const { return gas_teff_m / m_g_; }
  ///@}

  ///@{
  /// \name Arithmetic operators
  /// These operators are used when adding quantities in particles together.

  /// Assignment operator, resizes L_lambda as necessary.
  nbody_data& operator= (const nbody_data&);

  /// Adds every member of rhs, including both momenta and L_lambda,
  /// to the corresponding member of this object.
  nbody_data& operator+= (const nbody_data& rhs) {
    m_g_ += rhs.m_g_;
    p_g_ += rhs.p_g_;
    L_bol_ += rhs.L_bol_;
    SFR += rhs.SFR;
    // An all-NaN gas momentum (as left behind by operator*=) is
    // explicitly accepted by this check.
    assert (all (p_g_ == 0) || all (p_g_ != p_g_) ||
            (sqrt (dot (p_g_, p_g_)) / m_g_ < 1e-4));
    m_metals_ += rhs.m_metals_;
    m_s_ += rhs.m_s_;
    p_s_ += rhs.p_s_;
    m_m_s += rhs.m_m_s;
    age_m += rhs.age_m;
    age_l += rhs.age_l;
    gas_temp_m += rhs.gas_temp_m;
    gas_teff_m += rhs.gas_teff_m;
    L_lambda += rhs.L_lambda;
    return *this;
  }

  /** Element-wise multiplication of the scalar members.  Note that
      this operator does NOT operate on the vector members,
      i.e. L_lambda and the momenta, because that doesn't make sense
      given how the operator is used.  The gas momentum is set to
      NaN afterwards to ensure it's not used. */
  nbody_data& operator*= (const nbody_data& rhs) {
    m_g_ *= rhs.m_g_;
    L_bol_ *= rhs.L_bol_;
    SFR *= rhs.SFR;
    m_metals_ *= rhs.m_metals_;
    m_s_ *= rhs.m_s_;
    m_m_s *= rhs.m_m_s;
    age_m *= rhs.age_m;
    age_l *= rhs.age_l;
    gas_temp_m *= rhs.gas_temp_m;
    gas_teff_m *= rhs.gas_teff_m;

    // Poison the gas momentum so that any later use is noticed.
    p_g_ = blitz::quiet_NaN (T_float ());

    return *this;
  }

  /// See operator*= for an important note about how this works.
  nbody_data operator* (const nbody_data& rhs) const {
    nbody_data product (*this);
    product *= rhs;
    return product;
  }

  /** Multiplies every member, including the stellar momentum and
      L_lambda, by the scalar rhs.  The gas momentum is scaled as
      well but is then overwritten with NaN to ensure it's not
      used. */
  nbody_data& operator*= (T_float rhs) {
    m_g_ *= rhs;
    p_g_ *= rhs;
    L_bol_ *= rhs;
    SFR *= rhs;
    m_metals_ *= rhs;
    m_s_ *= rhs;
    p_s_ *= rhs;
    m_m_s *= rhs;
    age_m *= rhs;
    age_l *= rhs;
    gas_temp_m *= rhs;
    gas_teff_m *= rhs;
    L_lambda *= rhs;

    // Poison the gas momentum so that any later use is noticed.
    p_g_ = blitz::quiet_NaN (T_float ());

    return *this;
  }

  /// Returns a copy of this object multiplied by the scalar rhs, see
  /// operator*=(T_float).
  nbody_data operator* (T_float rhs) const {
    nbody_data scaled (*this);
    scaled *= rhs;
    return scaled;
  }

  /** A "floating multiply and add"-operator: adds d*f to this object,
      member by member.  This could have been implemented using some
      fancy operator magic, but that is not worth it.  L_lambda is
      only updated if d.L_lambda is non-empty.

      \param d The object to be added to this.
      \param f The scalar to multiply d with before adding.
      \param vol A volume, not used by the current implementation.  It
      would be necessary if there were intensive quantities, which
      currently there aren't. */
  void add_to (const nbody_data& d, T_float f, T_float vol) {
    assert (d.m_g_ >= 0);
    assert (d.L_bol_ >= 0);
    assert (m_g_ >= 0);
    assert (L_bol_ >= 0);

    m_g_ += d.m_g_ * f;
    p_g_ += d.p_g_ * f;
    L_bol_ += d.L_bol_ * f;
    SFR += d.SFR * f;
    assert (all (p_g_ == 0) || (sqrt (dot (p_g_, p_g_)) / m_g_ < 1e-4));
    m_metals_ += d.m_metals_ * f;
    m_s_ += d.m_s_ * f;
    p_s_ += d.p_s_ * f;
    m_m_s += d.m_m_s * f;
    age_m += d.age_m * f;
    age_l += d.age_l * f;
    gas_temp_m += d.gas_temp_m * f;
    gas_teff_m += d.gas_teff_m * f;

    if (d.L_lambda.size () > 0) {
      L_lambda += d.L_lambda * f;
    }
  }
  ///@}
};

/** Essentially nbody_data, extended with the (static) function that
    is used when the grid refinements are created.  It declares no
    data members of its own.  \ingroup makegrid */
class mcrx::nbody_data_cell : public nbody_data {
public:
  /// Creates a cell holding default-constructed nbody_data.
  nbody_data_cell ()
    : nbody_data ()
  {}

  /// Passes all arguments on, unchanged and in the same order, to the
  /// corresponding nbody_data constructor.
  nbody_data_cell (T_float m_gas, const vec3d& p_gas, T_float Lb,
                   T_float s, T_float met,
                   T_float m_star, const vec3d& p_star, T_float starmet,
                   T_float agem, T_float agel,
                   T_float mtemp, T_float mteff,
                   const array_1& Ll = array_1 ())
    : nbody_data (m_gas, p_gas, Lb, s, met,
                  m_star, p_star, starmet,
                  agem, agel,
                  mtemp, mteff, Ll)
  {}

  /** A "static downcast" that turns an nbody_data object into an
      nbody_data_cell.  This is perfectly safe since nbody_data_cell
      has no extra data members.  */
  nbody_data_cell (const nbody_data& d)
    : nbody_data (d)
  {}

  /** Calculates the proper quantities in a cell which is the
      unification of a number of smaller cells.  Because all
      quantities are extensive, the unification cell corresponding to
      a group of cells is simply the sum of their quantities, so the
      sum is returned as is and the int argument is ignored.  If
      there were intensive quantities, it would be the average, but
      there are none.  */
  static nbody_data_cell unification (const nbody_data_cell& sum, int) {
    return sum;
  }
};


/** This class is a predicate that returns whether a cell can be
    unified. It keeps the tolerances and the data necessary to test
    unification. */
class mcrx::tolerance_checker {
private:
  typedef adaptive_grid<nbody_data_cell>::T_cell T_cell;

  const nbody_data_cell tolerance_;          ///< Fractional tolerances.
  const nbody_data_cell tolerance_absolute_; ///< Absolute tolerances.
  const T_float gas_metallicity_;
  const T_float max_metal_column_;
  /// Value supplied through set_L_bol(); zero until that is called.
  T_float L_bol_tot_;

public:
  /** Stores copies of the tolerances and limits.

      \param tolerance Stored as the value returned by tolerance().
      \param tol_absolute Stored as the value returned by
      tolerance_absolute().
      \param gm Stored as the gas metallicity.
      \param maxcol Stored as the value returned by max_metal_column().

      L_bol_tot_ is initialized to zero so that it never holds an
      indeterminate value if set_L_bol() has not been called. */
  tolerance_checker (const nbody_data_cell& tolerance,
                     const nbody_data_cell& tol_absolute,
                     T_float gm, T_float maxcol) :
    tolerance_ (tolerance), tolerance_absolute_ (tol_absolute),
    gas_metallicity_ (gm), max_metal_column_ (maxcol),
    L_bol_tot_ (0) {}

  /// Sets the stored total bolometric luminosity.
  void set_L_bol (T_float l) { L_bol_tot_ = l; }
  /// Returns the fractional tolerances given to the constructor.
  const nbody_data_cell& tolerance () const { return tolerance_; }
  /// Returns the absolute tolerances given to the constructor.
  const nbody_data_cell& tolerance_absolute () const {
    return tolerance_absolute_;
  }
  /// Returns the maximum metal column given to the constructor.
  T_float max_metal_column () const { return max_metal_column_; }

  /** Tests if a collection of cells can be unified.  The function
      returns true if the collection of cells represented by sum and
      sumsq has a lower fractional deviation than the tolerance, or if
      the absolute standard deviation is smaller than the absolute
      tolerance, for all quantities, and the column density of metals
      is smaller than the specified maximum.  This is used for
      ensuring that the grid resolves a specific optical depth, which
      is important for the temperature calculations.  */
  bool unify_cell_p (const nbody_data_cell& sum,
                     const nbody_data_cell& sumsq,
                     int n, const T_cell::T_cell_tracker& c) const;

  /** Tests if a cell should be refined because it has too large a
      metal column density, based on the particles that are inside
      it. */
  bool refine_cell_p (const check_particle_overlap_p& cpo,
                      const T_cell::T_cell_tracker& c) const;
  void print_tolerances () const;
};


/** A grid containing data from the nbody particles, used for creating
    the hierarchical grid. \ingroup makegrid */
class mcrx::nbody_data_grid:
  public adaptive_grid <nbody_data_cell>
{
public:
  typedef adaptive_grid <nbody_data_cell> T_base;
  typedef refinement_accuracy_data<nbody_data_cell> T_racc;
  friend int ::main(int, char**);
  typedef particle<nbody_data> T_particle;
private:
  ///@{
  /// \name Refinement variables
  /// These members control the behavior of the adaptive refinement.
  
  /** Factor determining when refinement is truncated based on the
      size of the particles in the cell.  If the cell size is smaller
      than size_factor times the radius of the smallest contained
      particle, no further refinement is done.  */
  T_float size_factor;
  /// Maximum refinement level.
  int max_level;

  /// Predicate object for checking whether cells can be unified.
  // NOTE(review): std::auto_ptr is deprecated since C++11 and removed
  // in C++17; it is kept here because the file targets older C++.
  std::auto_ptr<tolerance_checker> tol_checker_;

  /// "Default" cell data, used if there are no particles in the cell.
  T_data data_zero;
  counter ctr;
  ///@}

  ///@{
  /// \name Units 
  /// The units are simply propagated from snapshot file to output file.
  std::string mass_unit;
  std::string time_unit;
  std::string length_unit; 
  std::string L_bol_unit;
  std::string SFR_unit;
  std::string temp_unit;
  std::string L_lambda_unit;
  ///@}
  
  /// Total grid quantities, summed over all cells
  nbody_data_cell total_quantities;

  ///@{
  /// \name Threading data structures
  /// These variables are used to manage the multithreaded refinement.

  /// Number of threads to use.  Mutable so that the const member
  /// use_threads() can set it.
  mutable int n_threads;
  /** Number of levels to refine before dividing the work among
      threads.  This is a load-balancing issue, see the
      makegrid-README.  */
  int work_chunk_levels;
  /// Mutex used to protect access to the vector of cells to process.
  mutable boost::mutex cell_stack_mutex;
  /// The vector of cells that needs to be processed.
  std::vector<T_code> cell_stack;
  /// The particles that the grid will be based on.
  std::vector<T_particle*> particle_list_;
  class thread_start; // function object used to start threads
  /// Keeps statistics about how many cells are at which refinement level.
  std::vector<int> creation_stats;
  void pop_cell_and_refine ();
  ///@}

  ///@{
  /** \name Load-balancing threading data structures 
      These variables are used to manage the multithreaded creation of
      a list of approximately load-balanced cells to work on.  (Yeah,
      it's weird.)  The cell_stack_mutex is also used for these threads. */

  /// Keeps track of the data for the load-balancing. \ingroup makegrid
  class balance_queue_data {
  public:
    typedef mcrx::nbody_data_grid::T_cell T_cell;
    
    T_code cell;///< Pointer to the cell
    /// List of particles that are in this cell 
    boost::shared_ptr<std::vector<T_particle*> > particle_list_;
    /// Creates an entry with a value-initialized cell and an empty
    /// particle list pointer.
    balance_queue_data (): cell () {}
    /// Creates an entry for cell c that shares the particle list pl.
    balance_queue_data (T_code c, 
			boost::shared_ptr<std::vector<T_particle*> > pl):
      cell (c), particle_list_ (pl) {}
  };
  /// The vector of cells that have been balanced 
  std::vector<balance_queue_data> balance_queue;
  class balance_thread_start;
  void process_balance_queue_cell ();
  ///@}
  

  ///@{
  /// \name Functions for grid building and refinement
  
  /// Builds the grid using the particles in particle_list_.
  void build_grid (); 
  void create_balanced_queue ();
  T_racc recursive_refine (const T_cell_tracker& c,
			   const std::vector<T_particle*>& pl,
			   std::vector<int>&);
  T_racc recursive_refine_body (const T_cell_tracker& c, 
				const std::vector<T_particle*>& pl,
				std::vector<int>&);
  T_racc project_particles (const T_cell_tracker& c,
			    const std::vector<T_particle*> & pl);
  void unrefine_if_possible (const T_cell_tracker& c,
			     T_racc& racc);
  ///@}
  
public:
  /** Constructs a grid, passing mi and ma on to the adaptive_grid
      base class.  The counter is constructed with 1 and the thread
      count and work chunk levels start at 0.  size_factor and
      max_level are zeroed as well so that they never hold
      indeterminate values. */
  nbody_data_grid (const vec3d & mi, const vec3d & ma):
    adaptive_grid<nbody_data_cell> (mi, ma),
    size_factor (0), max_level (0),
    ctr (1), n_threads (0), work_chunk_levels (0) {}
  /** Construct a nbody_data_grid reading the structure from the
      specified GRIDSTRUCTURE HDU, while also loading the length
      unit. */
  nbody_data_grid (CCfits::ExtHDU& input, 
		   const vec3d& translate_origin=vec3d(0,0,0));

  bool load_snapshot (const mcrx::Snapshot&,
		      const int ml, const T_float size_fudge,
		      const tolerance_checker&,
		      CCfits::HDU*info = 0, bool use_SED = true);
  void save_data (CCfits::FITS&, const std::string&, const std::string&) const;
  /** Loads grid cell data from a FITS GRIDDATA HDU into an existing
      grid structure.  */
  void load_data (CCfits::ExtHDU&);
  /// Sets the number of threads to use.  Declared const; it writes
  /// the mutable member n_threads.
  void use_threads (int n) const {n_threads = n;}
  /// Sets the number of levels to refine before the work is divided
  /// among threads.
  void set_work_chunk (int n) {work_chunk_levels = n;}

  /** Saves the grid structure to a FITS file by forwarding to
      T_base::save_structure with length_unit inserted as the third
      argument.  Because the signature differs from the four-argument
      base function that is called here, this declaration hides that
      function rather than overriding it. */
  void save_structure (CCfits::FITS& file, const std::string& hdu,
		       bool save_codes) const {
    T_base::save_structure (file, hdu, length_unit, save_codes);}
};

/** Predicate used to copy particles that overlap with a grid cell,
    while keeping track of the minimum size and the maximum estimated
    metal density of the overlapping particles.
    \ingroup makegrid */
class mcrx::check_particle_overlap_p :
  public std::unary_function<mcrx::nbody_data_grid::T_particle*, bool> {
private:
  // The global main must be named with a qualified name: an
  // unqualified friend declaration in this class would refer to a
  // function main in namespace mcrx instead.
  friend int ::main(int, char**);
  typedef mcrx::nbody_data_grid::T_float T_float;
  typedef mcrx::nbody_data_grid::T_particle T_particle;
  typedef mcrx::nbody_data_grid::T_cell T_cell;
  
  vec3d cellmin_, cellmax_;

  T_float gas_metallicity_;
  // Members that return data MUST be shared_ptrs because the object
  // is copied into the algorithm; all copies then update the same
  // values.
  boost::shared_ptr<T_float> min_size_;
  boost::shared_ptr<T_float> max_rho_;

public:
  /** Creates a predicate for the cell that cc refers to.

      \param cc Object whose getmin() and getmax() give the extent of
      the cell.
      \param gm Gas metallicity that multiplies the gas mass in the
      density estimate.

      The minimum size starts at the largest T_float value and the
      maximum density at zero. */
  template <typename T_grid_iterator>
  check_particle_overlap_p (const T_grid_iterator& cc, T_float gm=0) :
    cellmin_(cc.getmin()), cellmax_(cc.getmax()), gas_metallicity_ (gm), 
    min_size_ (new T_float (blitz::huge (T_float () ))),
    max_rho_(new T_float()) {}

  /// Returns the size of the smallest particles seen so far.
  T_float min_size () const {return *min_size_;}
  /// Returns the largest density estimate seen so far.
  T_float max_rho () const {return *max_rho_;}
  /** Returns true if the particle overlaps with the cell.  Also keeps
      track of the minimum size and the maximum density of the
      overlapping particles. */
  bool operator () (const T_particle*const p) {
    const bool o =p->overlap(cellmin_, cellmax_);
    if (o) {
      // Density estimate: metal mass plus gas mass times the gas
      // metallicity, divided by the cube of the particle radius.
      const T_float rho_est = (p->data().m_metals() + 
			       p->data().m_g()*gas_metallicity_)/
	(p->radius()*p->radius()*p->radius());
      *max_rho_ = std::max(*max_rho_, rho_est );
      *min_size_ = std::min (*min_size_, p->radius());
    }
    return o;
  }
};

/** Function object handed to boost threads to start the refinement
    threads.  Calling it just invokes pop_cell_and_refine () on the
    grid it was constructed with. \ingroup makegrid */
class mcrx::nbody_data_grid::thread_start {
public:
  /// Remembers the grid g that the thread will work on.
  thread_start (nbody_data_grid* g)
    : grid_ (g)
  {}

  /// Thread entry point.
  void operator () () {
    grid_->pop_cell_and_refine ();
  }

private:
  nbody_data_grid* grid_;
};

/** Function object handed to boost threads to start the load
    balancing threads.  Calling it just invokes
    process_balance_queue_cell () on the grid it was constructed with.
    \ingroup makegrid */
class mcrx::nbody_data_grid::balance_thread_start {
public:
  /// Remembers the grid g that the thread will work on.
  balance_thread_start (nbody_data_grid* g)
    : grid_ (g)
  {}

  /// Thread entry point.
  void operator () () {
    grid_->process_balance_queue_cell ();
  }

private:
  nbody_data_grid* grid_;
};


#endif

